
Conversation


@airMeng airMeng commented Aug 13, 2025

This PR adds a chunked prefill op. It currently works with oneAPI 2025.1.

Llama-3B BF16 accuracy results, verified on BMG-12GB. You need to install SGLang following the instructions at https://github.com/airMeng/sglang/blob/xpu_attention/docs/platforms/xpu.md

| backend   | gsm-8k | mmlu  |
|-----------|--------|-------|
| intel_xpu | 0.680  | 0.545 |
| triton    | 0.650  | 0.546 |

To reproduce the accuracy results, launch the server first:

python3 -m sglang.launch_server  --model /PATH/TO/MODEL  --dtype bfloat16 --tp 1  --trust-remote-code  --mem-fraction-static 0.8 --attention-backend intel_xpu --page-size 128 --port 3000
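
Once the server is up, you can sanity-check it before running the benchmarks. A minimal probe, assuming the standard SGLang HTTP endpoints on the port chosen above (adjust the URL if your deployment differs):

```python
import requests

base_url = "http://localhost:3000"

# Liveness check against the server launched above.
print(requests.get(f"{base_url}/health").status_code)  # expect 200

# One generation request to confirm the intel_xpu attention backend runs end to end.
resp = requests.post(
    f"{base_url}/generate",
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 16, "temperature": 0.0},
    },
)
print(resp.json()["text"])
```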

Then run the accuracy scripts in SGLang:

cd ~/sglang/benchmark/gsm8k
python3 bench_sglang.py --num-questions 200 --port 3000
cd ../mmlu
python3 bench_sglang.py --nsub 20 --port 3000

The PR does not work with the current open-source oneAPI due to a SYCLCompat issue. You can update your local oneAPI according to intel/llvm#19673.

@airMeng airMeng marked this pull request as draft August 13, 2025 03:12
@airMeng airMeng marked this pull request as ready for review September 9, 2025 05:17
@airMeng airMeng force-pushed the cutlass_attntion branch 2 times, most recently from 4319c23 to 6ad98d8 Compare September 9, 2025 06:07
Comment on lines 175 to 205
if cu_seqlens_q is None:  # !is_varlen_q
    cu_seqlens_q = torch.arange(
        0, q.size(0) + 1, dtype=torch.int, device=q.device
    ) * q.size(1)
    max_seqlen_q = q.size(1)
    q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
if cu_seqlens_k_new is None and k is not None:  # !is_varlen_k_new
    cu_seqlens_k_new = torch.arange(
        0, k.size(0) + 1, dtype=torch.int, device=k.device
    )
elif k is None:
    cu_seqlens_k_new = torch.zeros_like(
        cu_seqlens_q, dtype=torch.int32, device=q.device
    )
if cache_seqlens is not None:
    max_seqlen_k = cache_seqlens.max().item()
    assert cache_seqlens.size(0) + 1 == cu_seqlens_q.size(0)
    max_page_size_per_seq = page_table.size(1)
    num_pages_per_seq = torch.arange(
        0,
        cache_seqlens.size(0) * max_page_size_per_seq,
        max_page_size_per_seq,
        device=cache_seqlens.device,
    ).to(torch.int32)
    cu_seqlens_k = torch.concat(
        (
            torch.zeros(1, dtype=torch.int32, device=cache_seqlens.device),
            torch.cumsum(cache_seqlens, 0),
        )
    ).to(torch.int32)

Contributor

These ops cause a performance degradation compared to the Triton backend.

Collaborator

No worries, we are aware of this. This PR still needs a lot of changes.

Collaborator

No need to pay too much attention to it right now; it will be fixed later.
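
For reference, one direction for the fix is to build this length metadata once per batch and reuse it across attention calls, and to keep the `cache_seqlens.max()` reduction on device so the implicit host sync happens at most once per batch. A minimal sketch of that idea, with a hypothetical helper name that is not part of this PR:

```python
import torch

def build_seqlen_metadata(cache_seqlens: torch.Tensor, page_table: torch.Tensor):
    """Hypothetical helper: precompute per-batch metadata instead of rebuilding it per call."""
    device = cache_seqlens.device
    batch_size = cache_seqlens.size(0)
    max_pages = page_table.size(1)

    # Offsets into the flattened page table, one entry per sequence.
    num_pages_per_seq = torch.arange(
        0, batch_size * max_pages, max_pages, device=device, dtype=torch.int32
    )

    # Cumulative KV lengths with a leading zero, matching the code above.
    cu_seqlens_k = torch.nn.functional.pad(
        torch.cumsum(cache_seqlens, 0, dtype=torch.int32), (1, 0)
    )

    # Converting to a host int still syncs, but doing it once per batch
    # (rather than per attention call) amortizes the cost.
    max_seqlen_k = int(cache_seqlens.max())
    return num_pages_per_seq, cu_seqlens_k, max_seqlen_k
```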

@mingfeima mingfeima marked this pull request as draft September 15, 2025 07:44
std::optional<at::Tensor>& q_descale_, // (b, h_k), not (b, h)
std::optional<at::Tensor>& k_descale_, // (b, h_k)
std::optional<at::Tensor>& v_descale_, // (b, h_k)
std::optional<const at::Tensor>& page_table_, // (b_k, max_num_pages_per_seq)
Contributor

Why are we changing the function signature?

Comment on lines +275 to +287
if cu_seqlens_q is None:  # !is_varlen_q
    cu_seqlens_q = torch.arange(
        0, q.size(0) + 1, dtype=torch.int, device=q.device
    ) * q.size(1)
    max_seqlen_q = q.size(1)
    q = q.view(-1, q.size(-2), q.size(-1)).contiguous()
batch_size = cu_seqlens_q.numel() - 1
page_table = (
    torch.arange(0, batch_size, device=q.device)
    .to(torch.int32)
    .reshape([batch_size, 1])
    .contiguous()
)
Contributor

What extra functionality are we trying to provide here?

Collaborator Author

The current kernel implementation is aligned between vLLM and SGLang requests, so there will be some changes on the SGLang side.
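
To illustrate that alignment: for a dense (non-varlen) batch, the quoted code synthesizes a trivial page table with one page per sequence so the paged-KV kernel can be reused unchanged. A small sketch of what it produces, with hypothetical shapes chosen only for illustration:

```python
import torch

# Dense batch of 3 sequences of length 4: (batch, seqlen, heads, head_dim).
q = torch.randn(3, 4, 8, 64)

cu_seqlens_q = torch.arange(0, q.size(0) + 1, dtype=torch.int) * q.size(1)
print(cu_seqlens_q)   # tensor([ 0,  4,  8, 12], dtype=torch.int32)

batch_size = cu_seqlens_q.numel() - 1
page_table = torch.arange(0, batch_size).to(torch.int32).reshape([batch_size, 1])
print(page_table)     # tensor([[0], [1], [2]], dtype=torch.int32)
```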

#include "cutlass/util/device_memory.h"
#include "cutlass/util/packed_stride.hpp"
#include "cutlass/util/reference/device/gemm_complex.h"
#include "cutlass/util/reference/device/tensor_compare.h"
Collaborator

We don't need the header files for the verification code here.

if (params.page_table != nullptr && params.cu_seqlens_k != nullptr) {
  return run<true, true, cutlass::flash_attention::IndividualScheduler>(params);
} else {
  return 0;
Collaborator

Only use the paged-KV path?

CHECK_DEVICE(v_new);
TORCH_CHECK(k_new.stride(-1) == 1, "k_new tensor must have contiguous last dimension");
TORCH_CHECK(v_new.stride(-1) == 1, "v_new tensor must have contiguous last dimension");
int seqlen_k_new = !is_varlen_k_new ? k_new.size(1) : 1;
Collaborator

Should `seqlen_k_new` be 1, or 0?

at::Tensor out_accum, softmax_lse_accum;
auto outaccum_type = at::ScalarType::Float;

constexpr int PipelineStages = 0;
Collaborator

Set this to 2.

Comment on lines 451 to 454
#define CHECK_DEVICE(x) TORCH_CHECK(x.is_xpu(), #x " must be on XPU")
#define CHECK_SHAPE(x, ...) \
TORCH_CHECK(x.sizes() == torch::IntArrayRef({__VA_ARGS__}), #x " must have shape (" #__VA_ARGS__ ")")
#define CHECK_CONTIGUOUS(x) TORCH_CHECK(x.is_contiguous(), #x " must be contiguous")
Collaborator

Move these to utils.


@mingfeima mingfeima marked this pull request as ready for review October 11, 2025 03:00
@mingfeima
Collaborator

@airMeng fix lint

@airMeng airMeng merged commit 1b3eb39 into main Oct 11, 2025
3 checks passed
sunjiweiswift added a commit to sunjiweiswift/sgl-kernel-xpu that referenced this pull request Oct 21, 2025
* initialize Cutlass support
Add chunked prefill op

---------

Co-authored-by: Swift.Sun <[email protected]>